Sliceformer: Make Multi-head Attention as Simple as Sorting in Discriminative Tasks
As one of the most popular neural network modules, the Transformer plays a
central role in many fundamental deep learning models, e.g., ViT in
computer vision and BERT and GPT in natural language processing. The
effectiveness of the Transformer is often attributed to its multi-head
attention (MHA) mechanism. In this study, we discuss the limitations of MHA,
including the high computational complexity caused by its "query-key-value"
architecture and the numerical issues caused by its softmax operation.
Considering these problems and recent trends in the design of attention
layers, we propose an effective and efficient surrogate for the
Transformer, called Sliceformer. Sliceformer replaces the classic MHA
mechanism with an extremely simple "slicing-sorting" operation, i.e.,
projecting inputs linearly to a latent space and sorting them along each
feature dimension (also called a channel). For each feature dimension, the
sorting operation implicitly generates an attention map that is sparse,
full-rank, and doubly stochastic. We consider different
implementations of the slicing-sorting operation and analyze their impacts on
the Sliceformer. We test Sliceformer on the Long-Range Arena benchmark,
image classification, text classification, and molecular property prediction,
demonstrating its lower computational complexity and its effectiveness across
discriminative tasks. Sliceformer achieves comparable or better performance
than the Transformer and its variants, with lower memory cost and faster speed.
Moreover, the experimental results show that Sliceformer empirically reduces
the risk of mode collapse in the learned data representations. The code is
available at https://github.com/SDS-Lab/sliceformer.
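To make the slicing-sorting idea concrete, here is a minimal sketch under simplifying assumptions: tokens are plain Python lists, the projection matrix is fixed rather than learned, and each channel is sorted in plain ascending order (the paper considers several implementations, which this sketch does not reproduce). The names `project`, `slice_sort`, and `implicit_attention` are hypothetical. The last function builds the permutation matrix that a single channel's sort applies, showing why the implicit attention map is sparse (one nonzero per row), full-rank, and doubly stochastic.

```python
from typing import List

Matrix = List[List[float]]


def project(x: Matrix, w: Matrix) -> Matrix:
    """Linear projection of N tokens: (N x D) @ (D x M) -> (N x M)."""
    return [
        [sum(row[d] * w[d][m] for d in range(len(w))) for m in range(len(w[0]))]
        for row in x
    ]


def slice_sort(x: Matrix, w: Matrix) -> Matrix:
    """Hypothetical helper: project tokens, then sort each latent channel over the token axis."""
    v = project(x, w)
    n, m = len(v), len(v[0])
    out = [[0.0] * m for _ in range(n)]
    for c in range(m):
        column = sorted(v[i][c] for i in range(n))  # ascending sort of channel c
        for i in range(n):
            out[i][c] = column[i]
    return out


def implicit_attention(channel: List[float]) -> Matrix:
    """Hypothetical helper: permutation matrix P such that P @ channel == sorted(channel)."""
    n = len(channel)
    order = sorted(range(n), key=lambda i: channel[i])  # argsort (stable for ties)
    # Row i selects the token that lands at sorted position i, so every row and
    # every column contains exactly one 1.0: a doubly stochastic, full-rank map.
    return [[1.0 if j == order[i] else 0.0 for j in range(n)] for i in range(n)]


if __name__ == "__main__":
    x = [[1.0, 0.0], [0.0, 1.0], [2.0, -1.0]]  # 3 tokens with D = 2 features
    w = [[1.0, 0.5], [-1.0, 2.0]]              # assumed projection, M = 2 channels
    print(slice_sort(x, w))  # [[-1.0, -1.0], [1.0, 0.5], [3.0, 2.0]]
    p = implicit_attention([1.0, -1.0, 3.0])
    print([sum(r) for r in p], [sum(c) for c in zip(*p)])  # [1.0, 1.0, 1.0] [1.0, 1.0, 1.0]
```

Unlike softmax attention, no N x N score matrix is computed: each channel costs one sort, O(N log N), and the permutation matrix exists only implicitly.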
A Quasi-Wasserstein Loss for Learning Graph Neural Networks
When learning graph neural networks (GNNs) in node-level prediction tasks,
most existing loss functions are applied to each node independently, even though
node embeddings and their labels are non-i.i.d. because of the underlying graph
structure. To eliminate this inconsistency, in this study we propose a novel
Quasi-Wasserstein (QW) loss based on optimal transport defined on
graphs, leading to new learning and prediction paradigms for GNNs. In
particular, we design a "Quasi-Wasserstein" distance between the observed
multi-dimensional node labels and their estimations, optimizing the label
transport defined on graph edges. The estimations are parameterized by a GNN, in
which the optimal label transport can optionally determine the graph edge
weights. By reformulating the strict constraint on the label transport as a
Bregman divergence-based regularizer, we obtain the proposed Quasi-Wasserstein
loss, together with two efficient solvers that learn the GNN jointly with the
optimal label transport. When predicting node labels, our model combines the
output of the GNN with the residual component provided by the optimal label
transport, leading to a new transductive prediction paradigm. Experiments show
that the proposed QW loss applies to various GNNs and helps to improve their
performance in node-level classification and regression tasks.
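To illustrate what "label transport defined on graph edges" means, here is a minimal sketch of the unregularized transport cost on the simplest graph, a path, where flow conservation leaves exactly one feasible edge flow. The function name `path_label_transport`, the signed-flow convention, and the per-edge weighted cost are illustrative assumptions. The QW loss itself operates on general graphs, relaxes the transport constraint with a Bregman divergence-based regularizer, and learns the GNN jointly with the transport; none of that is reproduced here.

```python
from typing import List, Tuple

Matrix = List[List[float]]


def path_label_transport(
    y: Matrix, y_hat: Matrix, w: List[float]
) -> Tuple[float, Matrix, List[float]]:
    """Hypothetical helper: label transport on a path graph 0 - 1 - ... - (n-1).

    y, y_hat: n x C observed and estimated node labels.
    w: n - 1 edge weights; w[i] is the cost per unit of mass on edge (i, i+1).
    Returns (cost, flows, imbalance); flows[i][k] is the signed flow of label
    channel k on edge (i, i+1), positive when the nodes left of the edge hold
    more observed than estimated label mass.
    """
    n, c = len(y), len(y[0])
    if len(w) != n - 1:
        raise ValueError("a path with n nodes has n - 1 edges")
    flows = [[0.0] * c for _ in range(n - 1)]
    imbalance = [0.0] * c
    cost = 0.0
    for k in range(c):
        surplus = 0.0  # running sum of (y - y_hat) over nodes 0..i
        for i in range(n):
            surplus += y[i][k] - y_hat[i][k]
            if i < n - 1:
                # Flow conservation forces the flow on edge (i, i+1) to equal the
                # surplus accumulated to its left, so on a path the flow is unique.
                flows[i][k] = surplus
                cost += w[i] * abs(surplus)
        # Mass left over at the last node; a strict transport constraint requires 0.
        imbalance[k] = surplus
    return cost, flows, imbalance


if __name__ == "__main__":
    y = [[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]]      # observed labels, C = 2
    y_hat = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]]  # assumed GNN estimates
    print(path_label_transport(y, y_hat, [1.0, 2.0]))
    # (3.0, [[1.0, 0.0], [0.0, 1.0]], [0.0, 0.0])
```

On a graph with cycles many edge flows satisfy conservation, so the minimum-cost flow must be found by optimization, which is where the paper's solvers come in; the path case only shows the quantity being minimized.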
Regularized Optimal Transport Layers for Generalized Global Pooling Operations
Global pooling is one of the most important operations in many machine
learning models and tasks, serving information fusion and the representation
of structured data such as sets and graphs. However, without solid mathematical
foundations, its practical implementations often depend on empirical
mechanisms and thus lead to sub-optimal or even unsatisfactory performance. In
this work, we develop a novel and generalized global pooling framework through
the lens of optimal transport. The proposed framework is interpretable from the
perspective of expectation-maximization. Essentially, it learns an
optimal transport across sample indices and feature dimensions so that the
corresponding pooling operation maximizes the conditional expectation of the
input data. We demonstrate that most existing pooling methods are equivalent to
solving a regularized optimal transport (ROT) problem with different
specializations, and more sophisticated pooling operations can be implemented
by hierarchically solving multiple ROT problems. Making the parameters of the
ROT problem learnable, we develop a family of regularized optimal transport
pooling (ROTP) layers. We implement the ROTP layers as a new kind of deep
implicit layer. Their model architectures correspond to different optimization
algorithms. We test our ROTP layers in several representative set-level machine
learning scenarios, including multi-instance learning (MIL), graph
classification, graph set representation, and image classification.
Experimental results show that applying our ROTP layers can reduce the
difficulty of designing and selecting global pooling: our ROTP layers can
either imitate existing global pooling methods or yield new pooling
layers that fit the data better. The code is available at
https://github.com/SDS-Lab/ROT-Pooling.
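The claim that familiar pooling operations are special cases of a regularized optimal transport problem can be illustrated with a minimal sketch, assuming an entropic regularizer with a fixed strength `tau`, uniform marginals, and a per-feature weighted average as the readout. The names `entropic_ot_pool` and `two_sided` are hypothetical, and the paper's ROTP layers (learnable ROT parameters, implicit layers whose architectures correspond to different optimization algorithms) are not reproduced; this sketch simply runs a fixed number of log-domain Sinkhorn iterations. With only the feature-side marginal enforced, the plan reduces to a per-feature softmax over samples, which approaches mean pooling as `tau` grows and max pooling as `tau` shrinks.

```python
import math
from typing import List

Matrix = List[List[float]]


def _logsumexp(values: List[float]) -> float:
    """Numerically stable log(sum(exp(v))) for a non-empty list."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def entropic_ot_pool(x: Matrix, tau: float, two_sided: bool = True, iters: int = 200) -> List[float]:
    """Hypothetical helper: pool an N x D set into D features via an entropic plan P (N x D).

    P maximizes <X, P> + tau * H(P) subject to column sums 1/D and, if two_sided,
    row sums 1/N (uniform marginals are an assumption). Feature j is pooled as the
    P-weighted average of column j of X.
    """
    n, d = len(x), len(x[0])
    log_a, log_b = -math.log(n), -math.log(d)  # log of the uniform marginals
    k = [[x[i][j] / tau for j in range(d)] for i in range(n)]  # log of the Gibbs kernel
    f = [0.0] * n  # row log-scalings
    g = [0.0] * d  # column log-scalings
    for _ in range(iters if two_sided else 1):
        if two_sided:
            # Sinkhorn step enforcing row sums 1/N (each sample sends equal mass).
            f = [log_a - _logsumexp([k[i][j] + g[j] for j in range(d)]) for i in range(n)]
        # Sinkhorn step enforcing column sums 1/D (the only constraint when one-sided).
        g = [log_b - _logsumexp([k[i][j] + f[i] for i in range(n)]) for j in range(d)]
    pooled = []
    for j in range(d):
        weights = [math.exp(k[i][j] + f[i] + g[j]) for i in range(n)]  # column j of P
        pooled.append(sum(wt * x[i][j] for i, wt in enumerate(weights)) / sum(weights))
    return pooled


if __name__ == "__main__":
    x = [[1.0, 4.0], [3.0, 0.0], [2.0, 2.0]]  # N = 3 samples, D = 2 features
    print(entropic_ot_pool(x, 1e-3, two_sided=False))  # close to max pooling: [3.0, 4.0]
    print(entropic_ot_pool(x, 1e6, two_sided=False))   # close to mean pooling: [2.0, 2.0]
    print(entropic_ot_pool(x, 0.5))                    # a pooling between the two
```

Working in the log domain keeps the iterations stable for small `tau`, where a direct `exp(x / tau)` would overflow.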